#!/usr/bin/python3
# -*- coding: utf-8 -*-   vim: set fileencoding=utf-8 :

"""Find cyclic Python imports."""


# To do:
# - Produce dot files (and instructions for running dot)

import argparse
import ast
from collections import defaultdict, namedtuple
from functools import total_ordering
import logging
import os
import re
import sys

from typing import (
    Dict,
    List,
    ItemsView,
    Iterable,
    Iterator,
    NoReturn,
    Optional,
    Set,
    Tuple,
)


# Top-level names of packages/modules that are known not to be part of
# the 'pencil' package.
#
# When an absolute import starts with one of these names, it is logged at
# INFO level and otherwise ignored (see Module.from_import_node() and
# Module._from_absolute_import()). Absolute imports that start neither with
# 'pencil' nor with one of these names trigger a warning that suggests
# adding the name to this list.
KNOWN_EXTERNAL_MODULES = [
    "__future__",
    "astropy",
    "atexit",
    "code",
    "collections",
    "copy",
    "Cython",
    "datetime",
    "dill",
    "distutils",
    "doctest",
    "eqtools",
    "errno",
    "fileinput",
    "glob",
    "h5py",
    "idlpy",
    "importlib",
    "inspect",
    "itertools",
    "lic_internal",
    "math",
    "matplotlib",
    "mpi4py",
    "multiprocessing",
    "numpy",
    "os",
    "pen",
    "pexpect",
    "pickle",
    "plotly",
    "pylab",
    "re",
    "retrying",
    "scipy",
    "shutil",
    "socket",
    "string",
    "struct",
    "subprocess",
    "sys",
    "tempfile",
    "_thread",
    "thread",
    "time",
    "tqdm",
    "unittest",
    "vtk",
    "warnings",
    "weakref",
]


def main() -> None:
    """Command-line entry point: report import cycles among Python files.

    Collects the Python files named on the command line (or, with
    --recursive, all .py files below the given directories), builds the
    graph of their pencil-internal imports and searches it for cycles.

    Prints the cycles grouped by length and exits with status 1 if any
    cycle was found; prints a short success message otherwise.
    """
    args = parse_arguments()
    if args.recursive:
        python_files = list(find_python_files(args.path))
    else:
        python_files = list(args.path)
    if not python_files:
        # The 'path' argument is declared with nargs="*" and a recursive
        # search may find nothing, so the list can be empty. Without this
        # check, python_files[0] below would raise an IndexError.
        die("No Python files to check")

    graph = DependencyDict()
    # PENCIL_HOME is located once, starting from the first file.
    pencil_python_directory = os.path.join(
        find_pencil_home(python_files[0]), "python"
    )

    for python_file in python_files:
        deps = parse_imports(python_file, pencil_python_directory)
        if deps:
            graph[Module.from_file(python_file)] = deps
    logging.debug(graph)

    progress = ProgressIndicator(downsample_by=10000)

    # A cycle of length N is found N times (once per start node), so we
    # canonicalize every cycle and let the set deduplicate them.
    cycles = {
        cycle.canonicalize()
        for node in graph
        for cycle in dfs(graph, node, node, progress)
    }
    progress.finish()

    if cycles:
        print_cycles(cycles)
        print("Found {} import cycles".format(len(cycles)))
        sys.exit(1)
    else:
        print("No import cycles found")


def print_cycles(cycles: Iterable["Cycle"]) -> None:
    """Print the given cycles, grouped by length, shortest first.

    Within each group the cycles are sorted by their string form. The
    cycles typically arrive as a set whose iteration order differs from
    run to run, so sorting keeps the output reproducible and diff-able.
    """
    grouped_cycles = _group_by_length(cycles)
    for count in sorted(grouped_cycles):
        print("Cycles of length {}:".format(count))
        for cycle in sorted(grouped_cycles[count], key=str):
            print("  " + str(cycle))


def _group_by_length(
    cycles: Iterable["Cycle"]
) -> Dict[int, List["Cycle"]]:
    """Bucket the cycles by their length.

    Returns a mapping {length: [cycles of that length]}; within each
    bucket the cycles keep the order in which they were supplied.
    """
    buckets = defaultdict(list)
    for entry in cycles:
        size = len(entry)
        buckets[size].append(entry)
    return buckets


def parse_imports(
    filename: str, pencil_python_directory: str
) -> Set["Module"]:
    """Return the set of pencil modules imported by the given file.

    Arguments:
      filename :
          The Python source file to analyze.
      pencil_python_directory :
          The directory ${PENCIL_HOME}/python, used to resolve absolute
          'pencil.*' imports.

    Relative imports are resolved starting from the (real) directory of
    the file itself.

    """
    directory = os.path.dirname(os.path.realpath(filename))
    # Read the file as bytes and let the parser determine the encoding
    # (UTF-8 by default, or whatever a PEP 263 coding line declares).
    # Opening in text mode would decode with the locale's preferred
    # encoding, which is not necessarily the encoding of the source file.
    with open(filename, "rb") as source:
        syntax_tree = ast.parse(source.read(), filename)

    imports, imports_from = extract_imports(syntax_tree)
    dependencies = set()  # type: Set[Module]
    for i_node in imports:
        dependencies.update(
            Module.from_import_node(
                i_node, directory, filename, pencil_python_directory
            )
        )
    for if_node in imports_from:
        dependencies.update(
            Module.from_import_from_node(
                if_node, directory, filename, pencil_python_directory
            )
        )

    return dependencies


def extract_imports(
    syntax_tree: ast.Module
) -> Tuple[Iterator[ast.Import], Iterator[ast.ImportFrom]]:
    """Collect every Import and ImportFrom node of a Python AST.

    The whole tree is visited, not just the module level, so imports
    nested inside functions or conditional blocks are reported as well.

    Returns a pair (import nodes, from-import nodes) of iterators.

    """
    visitor = ImportExtractor()
    visitor.visit(syntax_tree)
    plain_imports = visitor.get_imports()
    from_imports = visitor.get_imports_from()
    return plain_imports, from_imports


# A (file, number, text) record.
# NOTE(review): no use of Line is visible in this file -- confirm whether
# it is still needed.
Line = namedtuple("Line", ["file", "number", "text"])


class ImportExtractor(ast.NodeVisitor):
    """AST visitor that collects 'import' and 'from ... import' nodes.

    After visit() has been called on a tree, the collected nodes are
    available through get_imports() and get_imports_from(). The nodes are
    kept in sets, so they are reported in no particular order.
    """

    def __init__(self) -> None:
        self.imports = set()  # type: Set[ast.Import]
        self.imports_from = set()  # type: Set[ast.ImportFrom]

    def visit_Import(self, node: ast.Import) -> None:
        """Record an 'import ...' statement, then descend into it."""
        self.imports.add(node)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Record a 'from ... import ...' statement, then descend."""
        self.imports_from.add(node)
        self.generic_visit(node)

    def get_imports(self) -> Iterator[ast.Import]:
        """Iterate over the recorded 'import ...' nodes."""
        for import_node in self.imports:
            yield import_node

    def get_imports_from(self) -> Iterator[ast.ImportFrom]:
        """Iterate over the recorded 'from ... import ...' nodes."""
        for import_from_node in self.imports_from:
            yield import_from_node


class Package:
    """A Python package, identified by its directory.

    E.g.
      pencil.math.derivatives

    A package is always represented by a __init__.py file in the package
    directory.

    Two packages are equal if their resolved directories are equal.
    """

    def __init__(self, directory: str):
        """Store the resolved directory; it must exist.

        Raises PackageDirectoryDoesNotExist if the resolved path is not a
        directory.
        """
        resolved = os.path.realpath(directory)
        self.directory = resolved
        if not os.path.isdir(resolved):
            raise PackageDirectoryDoesNotExist(resolved)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Package):
            return NotImplemented
        return self.directory == other.directory

    def __hash__(self) -> int:
        return hash(self.directory)

    def __repr__(self) -> str:
        return "Package {}".format(self.directory)


@total_ordering
class Module:
    """A Python module (essentially a .py file).

    For a package, point to the __init__.py file.

    Modules are equal (and hash equal) if they refer to the same file.
    They are ordered by their pretty-printed path, see __str__().
    """

    def __init__(self, package: "Package", name: str) -> None:
        """Arguments:

            package:
                The package whose directory contains the module.
            name:
                The module name, without directory and '.py' suffix.

        Raises an Exception if no matching file exists, see
        find_module_file().

        """
        self.package = package
        self.name = name
        self.file = self.find_module_file(package, name)

    @staticmethod
    def find_module_file(package: "Package", module_name: str) -> str:
        """Return the path of the file that represents a module.

        This is either '<module_name>.py' or '<module_name>/__init__.py'
        in the package directory; the former is preferred if both exist.
        Raises an Exception if neither exists.

        """
        directory = package.directory
        for cand in [
            module_name + ".py",
            os.path.join(module_name, "__init__.py"),
        ]:
            candidate = os.path.join(directory, cand)
            if os.path.exists(candidate):
                return candidate
        raise Exception(
            "No module file found for package {}, module {}".format(
                package, module_name
            )
        )

    @staticmethod
    def from_file(filename: str) -> "Module":
        """Create the Module that corresponds to a .py file.

        Terminates the program if the file name has no '.py' suffix.

        """
        if not filename.endswith(".py"):
            die("File name {} has no .py suffix".format(filename))
        package_dir, basename = os.path.split(
            re.sub(r"\.py$", "", filename)
        )
        return Module(Package(package_dir), basename)

    @staticmethod
    def from_import_node(
        node: ast.Import,
        directory: str,
        filename: str,
        pencil_python_directory: str,
    ) -> Iterator["Module"]:
        """Extract Module's from an "import ..." line.

        There are two cases:

        1. import pencil.module
        2. import some.external.module

        In the first case, we search for the imported file(s) under
        ${PENCIL_HOME}/python/[pencil/] .

        In the second case, we ignore the imports, as there cannot be
        cycles involving external packages. Before this, however, we
        verify that the top package from which we import is a known
        external package.

        """
        logging.debug(
            "from_import_node({}, {}, {})".format(
                _format_import_node(node),
                directory,
                pencil_python_directory,
            )
        )
        for alias in node.names:
            name = alias.name
            start = name.split(".")[0]
            if start == "pencil":
                # Treat 'import pencil.io.snapshot' like
                # 'from pencil.io import snapshot': the leading components
                # name the package path, the last one the module. Passing
                # the dotted name as a single component would make us look
                # for a file 'pencil.io.snapshot.py', which never exists.
                package_path, _, importee = name.rpartition(".")
                yield Module._from_import(
                    package_path or None,
                    importee,
                    pencil_python_directory,
                )
            elif start in KNOWN_EXTERNAL_MODULES:
                # Absolute import of some non-pencil module
                logging.info(
                    "Non-interesting import line <import {}>".format(name)
                )
            else:
                logging.warning(
                    "Unknown global import <import {}>".format(name)
                )
                logging.warning(
                    "Consider adding {} to KNOWN_EXTERNAL_MODULES.".format(
                        start
                    )
                )

    @staticmethod
    def from_import_from_node(
        node: ast.ImportFrom,
        directory: str,
        filename: str,
        pencil_python_directory: str,
    ) -> Iterator["Module"]:
        """Extract Module's from a "from ... import ..." line.

        There are three cases to distinguish:

        1. from pencil.module import ...
        2. from ..relative.package.path import ...
        2a. from .. import ...
        3. from some.external.module import ...

        In the first case, we search for the imported file(s) under
        ${PENCIL_HOME}/python/[pencil/] .

        In the second case, we search for the imported file(s) relative
        from the directory of the current file.

        In the third case, there cannot be cycles.

        """
        logging.debug(
            "from_import_from_node({}, {}, {}, {})".format(
                _format_from_import_node(node),
                directory,
                pencil_python_directory,
                filename,
            )
        )
        for alias in node.names:
            name = alias.name
            import_from = node.module
            if import_from is None:
                # An import of the form
                # 'from . import this' or 'from .. import that'
                if node.level < 1:
                    die(
                        "Unexpected: no module, no level"
                        " in {}, {}:{}".format(
                            node, filename, node.lineno
                        )
                    )
                yield Module._from_import(
                    import_from, name, directory, node.level
                )
            else:
                if node.level == 0:
                    yield from Module._from_absolute_import(
                        import_from, name, pencil_python_directory
                    )
                else:
                    # Relative import of the form
                    # 'from .here import ...' or  'from ..here import ...'
                    yield Module._from_import(
                        import_from, name, directory, node.level
                    )

    @staticmethod
    def _from_absolute_import(
        import_from: str, name: str, pencil_python_directory: str
    ) -> Iterator["Module"]:
        """Absolute import of the form

        'from os.path import isdir'
        or 'from pencil.io import snapshot'

        Yields the imported Module for pencil imports; yields nothing (and
        only logs) for imports from other packages.

        """
        start = import_from.split(".")[0]
        if start == "pencil":
            # 'from pencil.module import ...'
            yield Module._from_import(
                import_from, name, pencil_python_directory
            )
        elif start in KNOWN_EXTERNAL_MODULES:
            # 'from some.external.module import ...'
            logging.info(
                "Non-interesting import <from {} import {}>".format(
                    import_from, name
                )
            )
        else:
            logging.warning(
                "Unknown global import <from {} import {}>".format(
                    import_from, name
                )
            )
            logging.warning(
                "Consider adding {} to KNOWN_EXTERNAL_MODULES.".format(
                    start
                )
            )

    @staticmethod
    def _from_import(
        import_from: Optional[str],
        importee: str,
        start_directory: str,
        level: int = 0,
    ) -> "Module":
        """Find a Module based on information from an import line.

        The Module represents either a 'module.py' file or a
        'module/__init__.py' file. Finding it is complicated by the fact
        that the import line can describe an import of a module from a
        package, or the import of a symbol from a module.

        Examples of import lines are
          'import B',
          'from A import B',
          'from .A import B',
          'from ..A import B',
          'from . import B',
          'from .. import B'.

        Arguments:
          import_from :
              The package name given between the 'from' and the 'import'
              keywords, without leading dots. In the examples above, "A".
          importee :
              The imported module or symbol. In the examples above, "B".
          start_directory :
              Directory from where to search for the module file.
          level :
              The number of leading dots prefixing the package name. In
              the examples above, the levels are (0, 0, 1, 2, 1, 2), in
              order.

        Raises an Exception if the import cannot be resolved to a file.

        """
        logging.debug(
            "_from_import({}, {}, {}, {})".format(
                import_from, importee, start_directory, level
            )
        )
        # One leading dot is the start directory itself, each further dot
        # goes one directory up.
        if level > 0:
            subdir = os.path.join(".", *[".."] * (level - 1))
        else:
            subdir = "."
        if import_from is not None:
            subdir = os.path.join(subdir, *import_from.split("."))
        logging.debug(
            "_from_import: expanded {} → {}".format(import_from, subdir)
        )
        directory = os.path.realpath(
            os.path.join(start_directory, subdir)
        )
        parent_dir, last_part = os.path.split(directory)
        # First try the importee as a module inside the package; if that
        # fails, assume the importee is a symbol and the last component of
        # the package path is the module.
        candidates = [(directory, importee), (parent_dir, last_part)]
        logging.debug(
            "Module._from_import(): candidates = {}".format(candidates)
        )
        for package_dir, module_name in candidates:
            candidate_file = os.path.join(
                package_dir, module_name + ".py"
            )
            init_file = os.path.join(
                package_dir, module_name, "__init__.py"
            )
            # NB: In Python, a symbol provided by __init__.py wins over a
            # module.py file. So for example 'from pencil.sim import
            # is_sim_dir' will import the function 'is_sim_dir' defined in
            # pencil/sim/__init__.py, not the module is_sim_dir
            # represented by the file pencil/sim/is_sim_dir.py .
            #
            # Here, however, we turn this logic upside down.
            # Reason: To see what 'from pencil.sim import is_sim_dir'
            # resolves to, we need to know the symbols loaded in
            # pencil/sim/__init__.py -- but I don't think we can without
            # executing the code.
            # If we prefer pencil/sim/is_sim_dir.py if it exists, we might
            # get false positives from our cycle check, which is better
            # than false negatives.
            # And, yes, we should stop defining 'function_name' in a module
            # 'function_name.py'.
            if os.path.exists(candidate_file):
                logging.debug("Found: {}".format(candidate_file))
                return Module(Package(package_dir), module_name)
            elif os.path.exists(init_file):
                logging.debug("Found: {}".format(init_file))
                return Module(Package(package_dir), module_name)
        raise Exception(
            "Cannot resolve 'from {} import {}'".format(
                import_from, importee
            )
        )

    def __lt__(self, other: "Module") -> bool:
        return str(self) < str(other)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Module):
            return self.file == other.file
        else:
            return NotImplemented

    def __hash__(self) -> int:
        return hash(self.file)

    def __repr__(self) -> str:
        return "Module {} in {}/".format(self.name, self.package)

    def __str__(self) -> str:
        """Pretty-print module path.

        Everything up to '/python/pencil/' is replaced by 'pencil/', and
        a trailing '/__init__.py' is dropped.
        """
        path = re.sub(r".*/python/pencil/", "pencil/", self.file)
        path = re.sub(r"/__init__\.py$", "", path)
        return path


class Cycle:
    """An import cycle, stored as the ordered list of modules on it.

    Cycles compare equal (and hash equal) if their string forms are
    equal, so two rotations of the same cycle are only equal after
    canonicalize().
    """

    def __init__(self, module_sequence: List[Module]) -> None:
        """Arguments:

            module_sequence:
                The sequence of modules that forms a cycle. E.g. the cycle
                A → B → C → A would be represented as [A, B, C].

        Raises an Exception if a sequence of more than one module repeats
        its first module as last entry.

        """
        self.modules = module_sequence
        if len(module_sequence) > 1:
            first, last = module_sequence[0], module_sequence[-1]
            if last == first:
                raise Exception(
                    "Cycle() arguments must not repeat first Module as last entry"
                )

    def canonicalize(self) -> "Cycle":
        """Return this cycle rotated to start at its "smallest" module."""
        pivot, _ = min(enumerate(self.modules), key=snd)
        head = self.modules[pivot:]
        tail = self.modules[:pivot]
        return Cycle(head + tail)

    def __len__(self) -> int:
        return len(self.modules)

    def __iter__(self) -> Iterator[Module]:
        return iter(self.modules)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Cycle):
            return NotImplemented
        # Assume that __str()__ encapsulates everything relevant for
        # comparison
        return str(self) == str(other)

    def __hash__(self) -> int:
        return hash(str(self))

    def __str__(self) -> str:
        return "[" + ", ".join(str(m) for m in self.modules) + "]"


class DependencyDict:
    """A Dict[Module, Set[Module]] that returns set() for missing entries.

      We cannot use defaultdict, as that leads to

      RuntimeError: dictionary changed size during iteration

      Note that assigning to a key does not replace the stored set, but
      adds the assigned modules to it.

      """

    def __init__(self) -> None:
        self.dict = {}  # type: Dict[Module, Set[Module]]

    def __setitem__(self, key: Module, item: Set[Module]) -> None:
        # Merge into the existing entry, creating it on first use.
        self.dict.setdefault(key, set()).update(item)

    def __getitem__(self, key: Module) -> Set[Module]:
        # A lookup of a missing key yields a fresh empty set and does not
        # add the key to the underlying dict.
        return self.dict.get(key, set())

    def __iter__(self) -> Iterator[Module]:
        return iter(self.dict)

    def items(self) -> ItemsView[Module, Set[Module]]:
        return self.dict.items()


class ProgressIndicator:
    """A simple progress indicator for unspecified total work.

    Every 'downsample_by' calls of tick() print one marker character.
    Each output line starts with its line number followed by a colon,
    the remaining markers are dots.
    """

    def __init__(self, downsample_by: int, line_length: int = 80) -> None:
        """
        Arguments:
            downsample_by :
                Only print a marker every N events. Values below 1 are
                treated as 1.
            line_length :
                Print (up to) this many markers per line.

        """
        # Honour the argument (it used to be ignored in favour of a
        # hard-coded 10000).
        self.sample_by = max(1, downsample_by)
        self.line_length = line_length
        # Number of ticks left until the next marker is printed
        self.count = self.sample_by
        self.column = 0
        self.line = 0
        # Characters still to be printed as the prefix of the current line
        self.line_indicator = "{}:".format(self.line)

    def tick(self) -> None:
        """Increase update count (not necessarily visible)."""
        self.count -= 1
        if self.count < 1:
            self.display_tick()
            self.count += self.sample_by

    def display_tick(self) -> None:
        """Show update."""
        print(self.get_marker_char(), end="")
        self.column += 1
        if self.column >= self.line_length:
            # Line is full: start a new one with a fresh line indicator
            print()
            self.column = 0
            self.line += 1
            self.line_indicator = "{}:".format(self.line)
        else:
            sys.stdout.flush()

    def get_marker_char(self) -> str:
        """Return the next character to print.

        This is the next character of the line indicator as long as there
        is one left, and a dot otherwise.
        """
        if self.line_indicator:
            marker, self.line_indicator = (
                self.line_indicator[0],
                self.line_indicator[1:],
            )
            return marker
        else:
            return "."

    def finish(self) -> None:
        """Clean up"""
        print()


def expand_dot_notation(name: str) -> str:
    """Turn dotted notation like ..parent.sub into a relative path.

    The first leading dot stands for the current directory, every
    further leading dot for one step up; the remaining dots separate
    path components. So '..parent.sub' becomes './../parent/sub'.
    """
    remainder = name.lstrip(".")
    leading_dots = len(name) - len(remainder)
    segments = []
    if leading_dots:
        segments.append(".")
        segments.extend([os.path.pardir] * (leading_dots - 1))
    segments.extend(remainder.split("."))
    return os.path.join(*segments)


def find_python_files(root_dirs: Iterable[str]) -> Iterator[str]:
    """Yield the paths of all .py files below the given root directories.

    The directories are searched recursively. The program is terminated
    if one of the roots is not a directory.
    """
    for top in root_dirs:
        if not os.path.isdir(top):
            die("Not a directory: {}".format(top))
        for folder, _subdirs, names in os.walk(top):
            yield from (
                os.path.join(folder, entry)
                for entry in names
                if entry.endswith(".py")
            )


def find_pencil_home(python_file: str) -> str:
    """Find $PENCIL_HOME by searching upwards from the given file.

    Returns the closest ancestor directory of python_file that looks
    like a PENCIL_HOME directory (see _looks_like_pencil_home()).
    Terminates the program if there is no such directory.
    """
    path = os.path.realpath(python_file)
    while True:
        parent = os.path.dirname(path)
        if parent == path:
            # We have reached the file system root, whose parent is the
            # root itself. Stop here: the path never becomes empty, so
            # testing it for truth would loop forever.
            break
        path = parent
        if _looks_like_pencil_home(path):
            return path
    die(
        "Could not find PENCIL_HOME searching upwards from {}".format(
            python_file
        )
    )


def _looks_like_pencil_home(path: str) -> bool:
    """Tell whether path contains the entries typical for PENCIL_HOME.

    All of the marker paths listed below must exist under path.
    """
    markers = (
        "python/pencil",
        "samples",
        "license/GNU_public_license.txt",
    )
    for marker in markers:
        if not os.path.exists(os.path.join(path, marker)):
            return False
    return True


def dfs(
    graph: DependencyDict,
    start: Module,
    end: Module,
    progress: ProgressIndicator,
) -> Iterator[Cycle]:
    """Depth-first search algorithm for finding cycles in a graph.

    [From: https://stackoverflow.com/a/40834276]

    Yields a Cycle for every path that leads from start to end without
    visiting any other node twice.

    Arguments:
      graph : A dict {module => [included modules]}
      start : The start node for searching
      end : The end node for searching
      progress : Indicator that is ticked once per visited search state

    """
    # Each entry is (node to visit, nodes visited after the start node)
    pending = [(start, [])]  # type: List[Tuple[Module, List[Module]]]
    while pending:
        progress.tick()
        node, trail = pending.pop()
        if trail and node == end:
            yield Cycle(trail)
        else:
            pending.extend(
                (successor, trail + [successor])
                for successor in graph[node]
                if successor not in trail
            )


def snd(tup: Tuple[int, Module]) -> Module:
    """A type-specialized version of Haskell's 'snd' function.

    Returns the second element of an (index, module) pair. Used as the
    key function when Cycle.canonicalize() looks for the "smallest"
    module of a cycle.
    """
    return tup[1]


def parse_arguments() -> argparse.Namespace:
    """Parse the command line and configure logging accordingly.

    The -h/--help option is handled here rather than by argparse, so
    that usage examples can be appended to the generated help text; the
    program exits after printing the help.
    """
    levels = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")
    examples = """
Examples:
    cyclic-imports python/pencil/sim/*.py
    cyclic-imports python/pencil/{math,read}/*.py
    cyclic-imports --recursive python/pencil/
    cd $PENCIL_HOME/python; \\
        find pencil -name \\*.py | xargs utils/cyclic-imports"""

    parser = argparse.ArgumentParser(description=__doc__, add_help=False)
    add = parser.add_argument
    add(
        "-r",
        "--recursive",
        action="store_true",
        help="Recursively find all .py files under all given PATHs.",
    )
    add(
        "-h",
        "--help",
        action="store_true",
        help="Print this help text and exit.",
    )
    add(
        "-L",
        "--log-level",
        action="store",
        default="WARNING",
        choices=levels,
        help="Set the log level (one of {})".format(", ".join(levels)),
    )
    add("--debug", action="store_true", help="Set log-level to DEBUG.")
    add(
        "path",
        nargs="*",
        help="In non-recursive mode: Python file(s) to parse for imports."
        " With -r|--recursive: root directory/ies to search recursively.",
    )
    args = parser.parse_args()

    # --debug takes precedence over an explicit --log-level
    effective_level = "DEBUG" if args.debug else args.log_level
    logging.basicConfig(
        level=effective_level, format="%(levelname)s %(message)s"
    )

    if args.help:
        parser.print_help()
        print(examples)
        parser.exit()

    return args


def _format_import_node(node: ast.Import) -> str:
    """Describe an 'import ...' node by the names it imports."""
    imported_names = [alias.name for alias in node.names]
    return "Import[{}]".format(imported_names)


def _format_from_import_node(node: ast.ImportFrom) -> str:
    """Describe a 'from ... import ...' node.

    The description lists the module, the imported names and the level
    (number of leading dots) of the import.
    """
    imported_names = [alias.name for alias in node.names]
    template = "ImportFrom[module={}, names={}, level={}]"
    return template.format(node.module, imported_names, node.level)


def die(*msg_lines: str) -> NoReturn:
    """Print the given lines, one per line, and exit with status 1."""
    print(*msg_lines, sep="\n")
    raise SystemExit(1)


class AstImportException(Exception):
    """An error that refers to a statement of a parsed source file.

    The message is prefixed with the file name and the line number of
    the statement.
    """

    def __init__(
        self, statement: ast.stmt, filename: str, msg: str
    ) -> None:
        location_and_message = "{}:{} {}".format(
            filename, statement.lineno, msg
        )
        super().__init__(location_and_message)


class PackageDirectoryDoesNotExist(Exception):
    """Raised when a non-existing path is used as a package directory."""

    def __init__(self, directory: str) -> None:
        message = "Package directory '{}' does not exist".format(directory)
        super().__init__(message)


if __name__ == "__main__":
    main()
